
[Bugfix][CPU] Fix RotaryEmbedding fallback causing gibberish with --enforce-eager #31643

Merged
bigPYJ1151 merged 4 commits into vllm-project:main from ricky-chaoju:fix/cpu-enforce-eager-gibberish-output
Jan 5, 2026

Conversation

Contributor

@ricky-chaoju ricky-chaoju commented Jan 3, 2026

Summary

Fix gibberish output on CPU backend when --enforce-eager is enabled.

Resolves #31626.

When running vLLM on the CPU backend with --enforce-eager, models may produce incoherent or repetitive outputs.
This happens because enforce_eager=True sets custom_ops="all", enabling CustomOp dispatch on CPU. In this mode, CustomOp.dispatch_forward() selects forward_cpu() implementations when available.

Several CustomOp subclasses did not define forward_cpu(), causing them to fall back to the base class behavior, which delegates to forward_cuda(). On CPU, this path invokes the C++ CPU kernels, whose behavior diverges from the PyTorch native implementations in certain cases, leading to incorrect computations and degraded output quality.
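The fallback chain described above can be sketched as follows. This is a simplified, hypothetical stand-in, not vLLM's actual implementation; only the class and method names (CustomOp, forward_native, forward_cuda, forward_cpu, dispatch_forward) come from the discussion, and the arithmetic is a placeholder for the real kernels.

```python
class CustomOp:
    """Simplified sketch of the per-device dispatch pattern described above."""

    def forward_native(self, x):
        raise NotImplementedError  # PyTorch-native reference implementation

    def forward_cuda(self, x):
        raise NotImplementedError  # device kernel path

    def forward_cpu(self, x):
        # Base-class default before the fix: assume the CPU can reuse the
        # CUDA path. On the CPU backend this invokes the C++ CPU kernels,
        # which is where the divergence described above comes from.
        return self.forward_cuda(x)

    def dispatch_forward(self, device: str):
        # With custom_ops="all", dispatch selects the per-device method.
        return self.forward_cpu if device == "cpu" else self.forward_cuda


class RotaryEmbeddingWithoutCpuOverride(CustomOp):
    # Stand-in math: native and kernel paths intentionally disagree,
    # mimicking the behavioral divergence reported in this PR.
    def forward_native(self, x):
        return x * 2

    def forward_cuda(self, x):
        return x * 3


op = RotaryEmbeddingWithoutCpuOverride()
# Without a forward_cpu override, CPU dispatch silently takes the CUDA path:
assert op.dispatch_forward("cpu")(1) == 3
assert op.forward_native(1) == 2  # the answer the CPU backend should produce
```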

Root cause

Missing forward_cpu() implementations in:

  • RotaryEmbedding
  • RMSNorm
  • GemmaRMSNorm
  • RMSNormGated

Fix

Add explicit forward_cpu() methods that delegate to forward_native().
This ensures that, when running on CPU with custom ops enabled, these layers consistently use the PyTorch native implementations, restoring correct behavior while keeping the existing execution model unchanged.
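In sketch form, the fix amounts to the following (a minimal illustration, not the literal vLLM code; the arithmetic is a placeholder for the real layer math):

```python
class CustomOp:
    def forward_native(self, x):
        return x + 1  # stand-in for the PyTorch-native implementation

    def forward_cuda(self, x):
        return x - 1  # stand-in for the kernel path that misbehaves on CPU


class RMSNorm(CustomOp):
    # The fix: an explicit forward_cpu that delegates to forward_native,
    # so CPU dispatch never falls back to the CUDA/C++ kernel path.
    def forward_cpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)


assert RMSNorm().forward_cpu(1) == 2  # native path, not the kernel path
```

The same one-line delegation is applied to each of the four affected subclasses.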

Test Plan

Tested with Qwen3-0.6B on CPU with --enforce-eager:

vllm serve "Qwen/Qwen3-0.6B" --enforce-eager --max-model-len 4096

Before / After (CPU, --enforce-eager)

Before:
"you so you so so so you so so so so"
"ellaneousellaneousellaneous"

After:
"27 member states, and each member state has a population of..."
"I'm here to help. How can I assist you?"

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a critical bug that causes models to produce incoherent output when running on the CPU backend with --enforce-eager. The root cause was correctly identified: several CustomOp subclasses lacked a forward_cpu implementation, causing a fallback to C++ kernels that have behavioral inconsistencies with their PyTorch native counterparts. The fix, which involves adding explicit forward_cpu methods to RotaryEmbedding, RMSNorm, GemmaRMSNorm, and RMSNormGated that delegate to forward_native, is sound and directly resolves the issue. The changes are consistent, well-targeted, and ensure that the correct, native PyTorch implementations are used on the CPU, restoring model correctness. The implementation is clean and follows existing patterns in the codebase.

  Add explicit forward_cpu methods to CustomOp subclasses that delegate
  to forward_native, ensuring CPU backend uses PyTorch native implementation
  instead of buggy CPU C++ kernels when custom_ops='all'.

  Classes fixed:
  - RotaryEmbedding (rotary_embedding/base.py)
  - RMSNorm (layernorm.py)
  - GemmaRMSNorm (layernorm.py)
  - RMSNormGated (layernorm.py)

  Fixes vllm-project#31626

Signed-off-by: rickychen-infinirc <ricky.chen@infinirc.com>
ricky-chaoju force-pushed the fix/cpu-enforce-eager-gibberish-output branch from 6f24b52 to b75b49a on January 3, 2026 at 11:00
@chaunceyjiang
Collaborator

/cc @ProExpertProg PTAL.

@bigPYJ1151
Member

Hi @rickychen-infinirc, thanks for the catch. The root cause is that RMSNorm accepts non-contiguous inputs after #28103, but we didn't add support for that to the CPU kernels.

It's okay to fall back most custom ops to the torch native implementations, because they are not performance-critical on CPU and can be compiled by torch.compile in most cases. We can dispatch them to forward_native at:

    def forward_cpu(self, *args, **kwargs):
        # By default, we assume that CPU ops are compatible with CUDA ops.
        return self.forward_cuda(*args, **kwargs)

One special case is RoPE; I prefer to keep the CPU custom kernel for it in eager mode.

- Change CustomOp.forward_cpu() default to forward_native instead of
  forward_cuda, as most CPU custom kernels are not performance-critical
  and can have compatibility issues
- Remove redundant forward_cpu() from RMSNorm, GemmaRMSNorm, RMSNormGated
  since they now inherit the base class behavior

Signed-off-by: rickychen-infinirc <ricky.chen@infinirc.com>
ricky-chaoju force-pushed the fix/cpu-enforce-eager-gibberish-output branch from 7afc967 to 0bd69ef on January 5, 2026 at 09:07
@ricky-chaoju
Contributor Author

@bigPYJ1151 Thanks for the review and the clarification on the root cause!

I've updated the PR based on your suggestion:

  1. Changed CustomOp.forward_cpu() default to dispatch to forward_native() instead of forward_cuda()
  2. Removed the redundant forward_cpu() methods from RMSNorm, GemmaRMSNorm, and RMSNormGated - they now inherit the base class behavior
  3. Kept RotaryEmbedding.forward_cpu() using the CPU custom kernel (ops.rotary_embedding) as you mentioned it's more performance-sensitive in eager mode

Tested with Qwen3-0.6B on CPU with --enforce-eager and the output is now correct.
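The final design from the three points above can be sketched like this (a simplified illustration with placeholder return values, not the actual vLLM code; `ops.rotary_embedding` is the real kernel entry point mentioned in the discussion, represented here by a stub):

```python
class CustomOp:
    def forward_native(self, x):
        return "native"  # stand-in for the PyTorch-native implementation

    def forward_cpu(self, *args, **kwargs):
        # New default: fall back to the native implementation on CPU, since
        # most CPU custom kernels are not performance-critical and may have
        # compatibility issues (e.g. non-contiguous inputs).
        return self.forward_native(*args, **kwargs)


class RMSNorm(CustomOp):
    # No forward_cpu override needed any more: inherits the native fallback.
    pass


class RotaryEmbedding(CustomOp):
    def forward_cpu(self, x):
        # RoPE keeps its CPU custom kernel (ops.rotary_embedding in vLLM)
        # because it is performance-sensitive in eager mode.
        return "cpu-kernel"


assert RMSNorm().forward_cpu("x") == "native"
assert RotaryEmbedding().forward_cpu("x") == "cpu-kernel"
```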

bigPYJ1151 added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Jan 5, 2026
@mergify

mergify bot commented Jan 5, 2026

Hi @rickychen-infinirc, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

ricky-chaoju force-pushed the fix/cpu-enforce-eager-gibberish-output branch from 86ce853 to 0bd69ef on January 5, 2026 at 12:43

Signed-off-by: rickychen-infinirc <ricky.chen@infinirc.com>
@bigPYJ1151 bigPYJ1151 merged commit c455b77 into vllm-project:main Jan 5, 2026
47 checks passed
@ricky-chaoju ricky-chaoju deleted the fix/cpu-enforce-eager-gibberish-output branch January 6, 2026 11:01
LucasWilkinson pushed a commit to neuralmagic/vllm that referenced this pull request Jan 6, 2026
…nforce-eager (vllm-project#31643)

Signed-off-by: rickychen-infinirc <ricky.chen@infinirc.com>
yugong333 pushed a commit to yugong333/vllm that referenced this pull request Jan 9, 2026
…nforce-eager (vllm-project#31643)

Signed-off-by: rickychen-infinirc <ricky.chen@infinirc.com>
akh64bit pushed a commit to akh64bit/vllm that referenced this pull request Jan 16, 2026
…nforce-eager (vllm-project#31643)

Signed-off-by: rickychen-infinirc <ricky.chen@infinirc.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…nforce-eager (vllm-project#31643)

Signed-off-by: rickychen-infinirc <ricky.chen@infinirc.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
…nforce-eager (vllm-project#31643)

Signed-off-by: rickychen-infinirc <ricky.chen@infinirc.com>

Labels

ready ONLY add when PR is ready to merge/full CI is needed


Development

Successfully merging this pull request may close these issues.

[Bug][CPU Backend]: Gibberish output on CPU backend when --enforce-eager is enabled (Qwen3-0.6B)

4 participants